import torch
from torch import nn
# from mmcls.models.backbones.rednet import RedNet
from mmcls.models.backbones.rednet import RedNet
# import mmclassification.mmcls.models.backbones.rednet
# import mmclassification
import numpy as np
from torch.nn import init
class redNet101(nn.Module):
    """RedNet-101 backbone followed by a three-layer MLP classification head.

    The head average-pools one backbone feature map to a 2048-d vector and
    maps it through 2048 -> 1024 -> 512 -> ``num_classes`` linear layers with
    ReLU between them.

    Args:
        num_classes (int): Width of the final linear layer, i.e. the number
            of output logits. Defaults to 1.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.model = RedNet(101)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

        # Re-initialise the head weights with a small-std Gaussian. Biases keep
        # PyTorch's default nn.Linear initialisation.
        for fc in (self.fc1, self.fc2, self.fc3):
            init.normal_(fc.weight, mean=0, std=0.01)

    def forward(self, x):
        """Run the backbone and the MLP head.

        Args:
            x (torch.Tensor): Input batch passed unchanged to the RedNet
                backbone. NOTE(review): earlier comments suggest
                (B, 3, 224, 224) images — confirm with the data pipeline.

        Returns:
            tuple: ``(q, x_fea_list, weights)`` where

            * ``q`` is the raw (un-activated) output of ``fc3``, shape
              ``(B, num_classes)`` for a 4-D feature map;
            * ``x_fea_list`` is the backbone output, returned unchanged;
            * ``weights`` is ``[fc1.weight, fc2.weight, fc3.weight]``.
              Presumably the caller uses these for a regularisation term.
              Verify against the training loop.
        """
        x_fea_list = self.model(x)
        # NOTE(review): assumes ``self.model`` returns an indexable sequence of
        # feature maps and that entry 2 is the one meant for the head. Confirm
        # this against the RedNet configuration. fc1 takes 2048 inputs, so this
        # map must have 2048 channels, i.e. (B, 2048, H, W).
        x_fea = x_fea_list[2]

        pooled = torch.flatten(self.avgpool(x_fea), 1)  # (B, 2048)
        q = torch.nn.functional.relu(self.fc1(pooled))
        q = torch.nn.functional.relu(self.fc2(q))
        q = self.fc3(q)  # raw logits: no final activation
        return q, x_fea_list, [self.fc1.weight, self.fc2.weight, self.fc3.weight]

    @staticmethod
    def init_weights(m):
        """Kaiming-normal initialise ``m.weight`` if ``m`` is a Linear-like module.

        Matching is by substring: any module whose class name contains
        ``'Linear'`` is re-initialised, and every other module is left
        untouched. Suitable as a callback for ``nn.Module.apply``.

        Args:
            m (nn.Module): Module to (possibly) re-initialise in place.
        """
        # staticmethod: the function takes the module, not ``self``, so it
        # works both as ``redNet101.init_weights`` and via an instance.
        if 'Linear' in type(m).__name__:
            nn.init.kaiming_normal_(m.weight)  # Kaiming Gaussian initialisation


class redNet50(nn.Module):
    """RedNet-50 backbone followed by a three-layer MLP classification head.

    The head average-pools one backbone feature map to a 2048-d vector and
    maps it through 2048 -> 1024 -> 512 -> ``num_classes`` linear layers with
    ReLU between them. The head layers keep PyTorch's default ``nn.Linear``
    initialisation; no custom init is applied here.

    Args:
        num_classes (int): Width of the final linear layer, i.e. the number
            of output logits. Defaults to 1.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.model = RedNet(50)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run the backbone and the MLP head.

        Args:
            x (torch.Tensor): Input batch passed unchanged to the RedNet
                backbone.

        Returns:
            tuple: ``(q, x_fea_list, weights)`` where ``q`` is the raw
            (un-activated) output of ``fc3``, shape ``(B, num_classes)`` for
            a 4-D feature map. ``x_fea_list`` is the backbone output,
            unchanged. ``weights`` is ``[fc1.weight, fc2.weight, fc3.weight]``.
        """
        x_fea_list = self.model(x)
        # NOTE(review): assumes ``self.model`` returns an indexable sequence of
        # feature maps, with entry 2 feeding the head. fc1 takes 2048 inputs,
        # so this map must have 2048 channels, i.e. (B, 2048, H, W).
        x_fea = x_fea_list[2]

        pooled = torch.flatten(self.avgpool(x_fea), 1)  # (B, 2048)
        q = torch.nn.functional.relu(self.fc1(pooled))
        q = torch.nn.functional.relu(self.fc2(q))
        q = self.fc3(q)  # raw logits: no final activation
        return q, x_fea_list, [self.fc1.weight, self.fc2.weight, self.fc3.weight]

    @staticmethod
    def init_weights(m):
        """Kaiming-normal initialise ``m.weight`` if ``m`` is a Linear-like module.

        Matching is by substring: any module whose class name contains
        ``'Linear'`` is re-initialised, and every other module is left
        untouched. Suitable as a callback for ``nn.Module.apply``.

        Args:
            m (nn.Module): Module to (possibly) re-initialise in place.
        """
        # staticmethod: the function takes the module, not ``self``, so it
        # works both as ``redNet50.init_weights`` and via an instance.
        if 'Linear' in type(m).__name__:
            nn.init.kaiming_normal_(m.weight)  # Kaiming Gaussian initialisation

class redNet26(nn.Module):
    """RedNet-26 backbone followed by a three-layer MLP classification head.

    The head average-pools one backbone feature map to a 2048-d vector and
    maps it through 2048 -> 1024 -> 512 -> ``num_classes`` linear layers with
    ReLU between them. The head layers keep PyTorch's default ``nn.Linear``
    initialisation; no custom init is applied here.

    Args:
        num_classes (int): Width of the final linear layer, i.e. the number
            of output logits. Defaults to 1.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        self.model = RedNet(26)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run the backbone and the MLP head.

        Args:
            x (torch.Tensor): Input batch passed unchanged to the RedNet
                backbone.

        Returns:
            tuple: ``(q, x_fea_list, weights)`` where ``q`` is the raw
            (un-activated) output of ``fc3``, shape ``(B, num_classes)`` for
            a 4-D feature map. ``x_fea_list`` is the backbone output,
            unchanged. ``weights`` is ``[fc1.weight, fc2.weight, fc3.weight]``.
        """
        x_fea_list = self.model(x)
        # NOTE(review): assumes ``self.model`` returns an indexable sequence of
        # feature maps, with entry 2 feeding the head. fc1 takes 2048 inputs,
        # so this map must have 2048 channels, i.e. (B, 2048, H, W). Verify
        # that RedNet-26 produces 2048 channels at that index.
        x_fea = x_fea_list[2]

        pooled = torch.flatten(self.avgpool(x_fea), 1)  # (B, 2048)
        q = torch.nn.functional.relu(self.fc1(pooled))
        q = torch.nn.functional.relu(self.fc2(q))
        q = self.fc3(q)  # raw logits: no final activation
        return q, x_fea_list, [self.fc1.weight, self.fc2.weight, self.fc3.weight]

    @staticmethod
    def init_weights(m):
        """Kaiming-normal initialise ``m.weight`` if ``m`` is a Linear-like module.

        Matching is by substring: any module whose class name contains
        ``'Linear'`` is re-initialised, and every other module is left
        untouched. Suitable as a callback for ``nn.Module.apply``.

        Args:
            m (nn.Module): Module to (possibly) re-initialise in place.
        """
        # staticmethod: the function takes the module, not ``self``, so it
        # works both as ``redNet26.init_weights`` and via an instance.
        if 'Linear' in type(m).__name__:
            nn.init.kaiming_normal_(m.weight)  # Kaiming Gaussian initialisation
